{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5372055b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosmetic only: apply the 'chesterish' jupyterthemes style to this notebook.\n",
    "from jupyterthemes.stylefx import set_nb_theme\n",
    "set_nb_theme('chesterish')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11214a4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Order GPUs by PCI bus id and expose only device \"0\" to this process.\n",
    "# These are set in their own cell ahead of the cell that imports torch;\n",
    "# presumably they must be in place before CUDA is initialised to take effect.\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f45eb6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import datasets \n",
    "\n",
    "from datasets import load_dataset, load_metric\n",
    "\n",
    "from transformers import (\n",
    "    AutoModel,\n",
    "    AutoModelForMaskedLM,\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForTokenClassification,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    pipeline,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d26af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-trained model and tokenizer\n",
    "model_name = \"t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363045f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(batch):\n",
    "    \"\"\"Tokenize a batch of articles and their reference summaries.\n",
    "\n",
    "    Articles are truncated/padded to 512 tokens and highlights to 128 tokens.\n",
    "    The resulting ``input_ids``, ``attention_mask`` and ``labels`` entries are\n",
    "    written into ``batch``, which is then returned.\n",
    "\n",
    "    NOTE(review): the labels keep the tokenizer's padding ids as-is; whether\n",
    "    padded positions should be excluded from the loss is worth confirming.\n",
    "    \"\"\"\n",
    "    article_encoding = tokenizer(\n",
    "        batch[\"article\"],\n",
    "        max_length=512,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "    summary_encoding = tokenizer(\n",
    "        batch[\"highlights\"],\n",
    "        max_length=128,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    batch[\"input_ids\"] = article_encoding[\"input_ids\"]\n",
    "    batch[\"attention_mask\"] = article_encoding[\"attention_mask\"]\n",
    "    # Copy so the labels do not alias the tokenizer's own output list.\n",
    "    batch[\"labels\"] = summary_encoding[\"input_ids\"].copy()\n",
    "    return batch\n",
    "\n",
    "# Fetch the CNN/DailyMail 3.0.0 splits: all of train, first 10% of validation\n",
    "train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
    "val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
    "\n",
    "\n",
    "def _tokenize_split(raw_split):\n",
    "    \"\"\"Run preprocess_function over a split in batches of 256, dropping the raw columns.\"\"\"\n",
    "    return raw_split.map(\n",
    "        preprocess_function,\n",
    "        batched=True,\n",
    "        batch_size=256,\n",
    "        remove_columns=[\"article\", \"highlights\", \"id\"],\n",
    "    )\n",
    "\n",
    "\n",
    "# Both splits go through exactly the same tokenization\n",
    "train_ds = _tokenize_split(train_data)\n",
    "val_ds = _tokenize_split(val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d58818f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLightningModule(pl.LightningModule):\n",
    "    \"\"\"Lightning wrapper around a pre-trained Hugging Face seq2seq model.\n",
    "\n",
    "    NOTE(review): `pl` is not imported in any visible cell of this notebook;\n",
    "    presumably `import pytorch_lightning as pl` has to run before this cell.\n",
    "\n",
    "    Args:\n",
    "        model_name: identifier passed to ``from_pretrained`` for both the\n",
    "            tokenizer and the model.\n",
    "        learning_rate: learning rate used by ``configure_optimizers``.\n",
    "        weight_decay: weight decay used by ``configure_optimizers``.\n",
    "        batch_size: stored on the module; not used by the module itself.\n",
    "        num_training_steps: stored on the module; not used by the module itself.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        self.batch_size = batch_size\n",
    "        self.num_training_steps = num_training_steps\n",
    "        \n",
    "        # Load the pre-trained model and tokenizer\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
    "        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        \"\"\"Run the wrapped model and return ``(loss, logits)`` from its output.\"\"\"\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute and return the training loss for one batch.\n",
    "\n",
    "        ``batch`` must provide ``input_ids``, ``attention_mask`` and ``labels``.\n",
    "        \"\"\"\n",
    "        input_ids = batch[\"input_ids\"]\n",
    "        attention_mask = batch[\"attention_mask\"]\n",
    "        labels = batch[\"labels\"]\n",
    "        \n",
    "        # The original body ended in a bare, undefined `loss`; run the forward\n",
    "        # pass with labels so the model produces the loss, and return it.\n",
    "        loss, _ = self(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Build an AdamW optimizer from the stored learning rate and weight decay.\"\"\"\n",
    "        return torch.optim.AdamW(\n",
    "            self.parameters(),\n",
    "            lr=self.learning_rate,\n",
    "            weight_decay=self.weight_decay,\n",
    "        )\n",
    "\n",
    "# Collator that assembles seq2seq batches for this tokenizer/model pair\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
    "\n",
    "# Trainer configuration, grouped by concern\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    # output\n",
    "    output_dir=\"./results\",\n",
    "    push_to_hub=False,\n",
    "    # optimisation\n",
    "    learning_rate=1e-5,\n",
    "    weight_decay=1e-4,\n",
    "    max_steps=5000,\n",
    "    gradient_accumulation_steps=1,\n",
    "    fp16=True,\n",
    "    # batch sizes\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    # evaluation / generation\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    predict_with_generate=True,\n",
    "    generation_max_length=128,\n",
    "    # logging\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# ROUGE scorer consumed by compute_metrics\n",
    "metric = load_metric(\"rouge\")\n",
    "\n",
    "# Define the evaluation function\n",
    "def compute_metrics(pred):\n",
    "    \"\"\"Compute ROUGE-1 precision, recall and F-measure for generated summaries.\n",
    "\n",
    "    Args:\n",
    "        pred: evaluation output exposing ``predictions`` (generated token ids)\n",
    "            and ``label_ids`` (reference token ids).\n",
    "\n",
    "    Returns:\n",
    "        dict with ``rouge1_precision``, ``rouge1_recall`` and ``rouge1_fmeasure``,\n",
    "        taken from the ``mid`` aggregate of the ROUGE-1 score.\n",
    "    \"\"\"\n",
    "    labels = pred.label_ids\n",
    "    preds = pred.predictions\n",
    "    # Predictions may arrive as a tuple; the token ids are its first element.\n",
    "    if isinstance(preds, tuple):\n",
    "        preds = preds[0]\n",
    "    # -100 is the ignore index used to pad labels (and gathered predictions).\n",
    "    # It is not a real token id, so map it to the pad token before decoding.\n",
    "    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)\n",
    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
    "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "    scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
    "    return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
    "\n",
    "\n",
    "# Initialize the trainer\n",
    "# Pass the tokenized splits (train_ds / val_ds), not the raw text datasets:\n",
    "# the input_ids / attention_mask / labels columns only exist after\n",
    "# preprocess_function has been mapped over the data.\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=val_ds,\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "# Start the training\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5148159b",
   "metadata": {},
   "source": [
    "# Steps:\n",
    "1. Rewrite the code to be more general\n",
    "\n",
    "a) Load the data from disk, on the fly, rather than through Hugging Face's `load_dataset`\n",
    "\n",
    "b) Rewrite the training code (Trainer etc.) using Lightning; it is fine to keep using Hugging Face to compute the metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95e33e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the current GPU status by shelling out to the nvidia-smi CLI\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0348c2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
